{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b02083d2-55a3-41e6-aed5-8ff633af6a59",
   "metadata": {},
   "source": [
    "## Get started with SageMaker\n",
    "In this notebook you'll learn how SageMaker can be used to:\n",
    "\n",
    "1. Preprocess (and optionally explore) a dataset\n",
    "2. Train an XGBoost classifier for customer churn prediction with a managed SageMaker Training job and a managed container image\n",
    "3. Tune hyperparameters with a managed SageMaker Hyperparameter Tuning job to find a better-performing set of values\n",
    "4. Perform batch inference with a managed SageMaker Batch Transform job\n",
    "5. Create a managed real-time SageMaker endpoint\n",
    "\n",
    "All SageMaker resources are created using the SageMaker Core SDK. You can find more information about sagemaker-core in its [documentation](https://sagemaker-core.readthedocs.io/en/latest/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0489c209-9628-4ddf-a849-8b973234126a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade pip -q\n",
    "%pip install sagemaker-core -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "314d959f-8c4a-4996-b167-24958d53417a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from sagemaker.core.helper.session_helper import Session, get_execution_role\n",
    "\n",
    "# Set up region, role and bucket parameters used throughout the notebook.\n",
    "sagemaker_session = Session()\n",
    "region = sagemaker_session.boto_region_name\n",
    "role = get_execution_role()\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "\n",
    "print(f\"AWS region: {region}\")\n",
    "print(f\"Execution role: {role}\")\n",
    "print(f\"Default S3 bucket: {bucket}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93b62649-4d89-4b6e-9f18-a48aa0bb27a4",
   "metadata": {},
   "source": [
    "## Preprocess dataset\n",
    "We'll use a synthetic dataset that AWS provides for customer churn prediction.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f347ca27-68a3-41a1-998c-434d39f8245a",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-info\">\n",
    "<b>NOTE:</b> This sample skips exploratory data analysis because the preprocessing steps for this dataset are already known.\n",
    "    \n",
    "If you're interested in exploratory analysis, the sagemaker-python-sdk example documentation has a section that explores this dataset, [here](https://sagemaker-examples.readthedocs.io/en/latest/introduction_to_applying_machine_learning/xgboost_customer_churn/xgboost_customer_churn.html).\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38f74911-14c9-4ef8-bd18-db9664de7dcf",
   "metadata": {},
   "source": [
    "#### Read the data from S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c16f79-f58f-412f-8716-c62179a270e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import StringIO\n",
    "import pandas as pd\n",
    "\n",
    "data = sagemaker_session.read_s3_file(\n",
    "    f\"sagemaker-example-files-prod-{region}\",\n",
    "    \"datasets/tabular/synthetic/churn.txt\"\n",
    ")\n",
    "\n",
    "df = pd.read_csv(StringIO(data))\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e00d8db-e6c8-4233-8437-8c2076895416",
   "metadata": {},
   "source": [
    "#### Apply processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de3d40e5-6d29-443e-98b1-8629ce106cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Phone number is unique - will not add value to classifier\n",
    "df = df.drop(\"Phone\", axis=1)\n",
    "\n",
    "# Cast Area Code to non-numeric\n",
    "df[\"Area Code\"] = df[\"Area Code\"].astype(object)\n",
    "\n",
    "# Remove one feature from each highly correlated pair\n",
    "df = df.drop([\"Day Charge\", \"Eve Charge\", \"Night Charge\", \"Intl Charge\"], axis=1)\n",
    "\n",
    "# One-hot encode categorical features into numeric features\n",
    "model_data = pd.get_dummies(df)\n",
    "model_data = pd.concat(\n",
    "    [model_data[\"Churn?_True.\"], model_data.drop([\"Churn?_False.\", \"Churn?_True.\"], axis=1)], axis=1\n",
    ")\n",
    "model_data = model_data.astype(float)\n",
    "\n",
    "# Split the data: about 67% for training, with the remaining 33% held out\n",
    "train_data, validation_data = train_test_split(\n",
    "    model_data, test_size=0.33, random_state=42)\n",
    "\n",
    "# Further split the held-out rows into validation (about 22% of all rows) and test (about 11%) datasets.\n",
    "validation_data, test_data = train_test_split(\n",
    "    validation_data, test_size=0.33, random_state=42)\n",
    "\n",
    "# Remove and store the target column for the test data. It is used after training to calculate performance metrics on unseen data.\n",
    "test_target_column = test_data['Churn?_True.']\n",
    "# Reassign instead of dropping in place, which avoids pandas' SettingWithCopyWarning on a split result\n",
    "test_data = test_data.drop(columns=['Churn?_True.'])\n",
    "\n",
    "# Store all datasets locally\n",
    "train_data.to_csv(\"train.csv\", header=False, index=False)\n",
    "validation_data.to_csv(\"validation.csv\", header=False, index=False)\n",
    "test_data.to_csv(\"test.csv\", header=False, index=False)\n",
    "\n",
    "# Upload each dataset to S3\n",
    "s3_train_input = sagemaker_session.upload_data('train.csv', bucket)\n",
    "s3_validation_input = sagemaker_session.upload_data('validation.csv', bucket)\n",
    "s3_test_input = sagemaker_session.upload_data('test.csv', bucket)\n",
    "\n",
    "print('Datasets uploaded to:')\n",
    "print(s3_train_input)\n",
    "print(s3_validation_input)\n",
    "print(s3_test_input)"
   ]
  },
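  {
   "cell_type": "markdown",
   "id": "4e7a1c52-9d3b-4b0e-8a67-1f2c6d5e9a10",
   "metadata": {},
   "source": [
    "The two successive `test_size=0.33` splits above do not yield three equal parts: the second split divides only the rows held out by the first. Here is a minimal sketch of the arithmetic. It assumes a hypothetical row count of 5,000 and assumes the held-out part is rounded up, which is how scikit-learn treats a float `test_size` when `train_size` is not given. `split_sizes` is an illustrative helper, not a library function.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def split_sizes(n_rows, test_size):\n",
    "    '''Return (n_kept, n_held_out) for one split with a float test_size.'''\n",
    "    # Assumption: the held-out part is rounded up to a whole number of rows.\n",
    "    n_held_out = math.ceil(test_size * n_rows)\n",
    "    return n_rows - n_held_out, n_held_out\n",
    "\n",
    "\n",
    "n_rows = 5000  # hypothetical row count, for illustration only\n",
    "n_train, n_rest = split_sizes(n_rows, 0.33)\n",
    "n_validation, n_test = split_sizes(n_rest, 0.33)\n",
    "\n",
    "for name, n in [('train', n_train), ('validation', n_validation), ('test', n_test)]:\n",
    "    print(f'{name}: {n} rows ({n / n_rows:.1%})')\n",
    "```\n",
    "\n",
    "Under these assumptions the data ends up split roughly 67% / 22% / 11% across train, validation and test."
   ]
  },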
  {
   "cell_type": "markdown",
   "id": "5d8c4139-244d-4827-b872-35fbeed664d8",
   "metadata": {},
   "source": [
    "## Train a classifier using XGBoost\n",
    "Use SageMaker Training and the managed XGBoost image to train a classifier. <br />\n",
    "More details on how to use SageMaker managed training with XGBoost can be found [here](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbbf270c-20c0-4f82-a8a6-c2794ba57f70",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-info\">\n",
    "  <b>NOTE:</b> For more information on SageMaker managed container images and how to retrieve their ECR paths, see the\n",
    "  <a href=\"https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths.html\" target=\"_blank\">documentation</a>.\n",
    "  The image URI in the next cell points to <code>eu-west-1</code>; update it if you run this notebook in a different AWS region.\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "268583f2-2c20-401b-a57c-ee88fd402583",
   "metadata": {},
   "outputs": [],
   "source": [
    "image = '141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-xgboost:1.7-1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e85a91b5-4a68-43f4-b04a-991035cc35bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.resources import TrainingJob\n",
    "from sagemaker.core.shapes import AlgorithmSpecification, Channel, DataSource, S3DataSource, ResourceConfig, StoppingCondition, OutputDataConfig\n",
    "\n",
    "job_name = 'xgboost-churn-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())  # Name of training job\n",
    "instance_type = 'ml.m4.xlarge'  # SageMaker instance type to use for training\n",
    "instance_count = 1  # Number of instances to use for training\n",
    "volume_size_in_gb = 30  # Amount of storage to allocate to training job\n",
    "max_runtime_in_seconds = 600  # Maximum runtime in seconds. SageMaker stops the job if it has not finished by then\n",
    "s3_output_path = f\"s3://{bucket}\"  # Bucket (and optional prefix) where the training job stores output artifacts, such as the model artifact\n",
    "\n",
    "# Specify hyperparameters\n",
    "hyper_parameters = {\n",
    "    \"max_depth\": \"5\",\n",
    "    \"eta\": \"0.2\",\n",
    "    \"gamma\": \"4\",\n",
    "    \"min_child_weight\": \"6\",\n",
    "    \"subsample\": \"0.8\",\n",
    "    \"verbosity\": \"0\",\n",
    "    \"objective\": \"binary:logistic\",\n",
    "    \"num_round\": \"100\",\n",
    "}\n",
    "\n",
    "# Create training job.\n",
    "training_job = TrainingJob.create(\n",
    "    training_job_name=job_name,\n",
    "    hyper_parameters=hyper_parameters,\n",
    "    algorithm_specification=AlgorithmSpecification(\n",
    "        training_image=image,\n",
    "        training_input_mode='File'\n",
    "    ),\n",
    "    role_arn=role,\n",
    "    input_data_config=[\n",
    "        Channel(\n",
    "            channel_name='train',\n",
    "            content_type='csv',\n",
    "            data_source=DataSource(\n",
    "                s3_data_source=S3DataSource(\n",
    "                    s3_data_type='S3Prefix',\n",
    "                    s3_uri=s3_train_input,\n",
    "                    s3_data_distribution_type='FullyReplicated'\n",
    "                )\n",
    "            )\n",
    "        ),\n",
    "        Channel(\n",
    "            channel_name='validation',\n",
    "            content_type='csv',\n",
    "            data_source=DataSource(\n",
    "                s3_data_source=S3DataSource(\n",
    "                    s3_data_type='S3Prefix',\n",
    "                    s3_uri=s3_validation_input,\n",
    "                    s3_data_distribution_type='FullyReplicated'\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "    ],\n",
    "    output_data_config=OutputDataConfig(\n",
    "        s3_output_path=s3_output_path\n",
    "    ),\n",
    "    resource_config=ResourceConfig(\n",
    "        instance_type=instance_type,\n",
    "        instance_count=instance_count,\n",
    "        volume_size_in_gb=volume_size_in_gb\n",
    "    ),\n",
    "    stopping_condition=StoppingCondition(\n",
    "        max_runtime_in_seconds=max_runtime_in_seconds\n",
    "    )\n",
    ")\n",
    "\n",
    "# Wait for the training job to complete\n",
    "training_job.wait()"
   ]
  },
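  {
   "cell_type": "markdown",
   "id": "8b2f6d31-0c4e-4a7b-b5d9-3e1a7f9c2d42",
   "metadata": {},
   "source": [
    "The `hyper_parameters` dictionary above passes every value as a string, including numeric ones such as `max_depth`. If you keep typed values elsewhere in your code, a small conversion step can produce that form. This is a minimal sketch: `to_hyper_parameters` and `typed_params` are hypothetical names used for illustration and are not part of sagemaker-core.\n",
    "\n",
    "```python\n",
    "def to_hyper_parameters(typed_params):\n",
    "    '''Convert a dict of typed values into a dict of strings.'''\n",
    "    return {name: str(value) for name, value in typed_params.items()}\n",
    "\n",
    "\n",
    "# Hypothetical typed values that mirror some of the ones used in this notebook\n",
    "typed_params = {\n",
    "    'max_depth': 5,\n",
    "    'eta': 0.2,\n",
    "    'objective': 'binary:logistic',\n",
    "    'num_round': 100,\n",
    "}\n",
    "\n",
    "hyper_parameters_as_strings = to_hyper_parameters(typed_params)\n",
    "print(hyper_parameters_as_strings)\n",
    "```"
   ]
  },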
  {
   "cell_type": "markdown",
   "id": "8db8e0b6-d000-453a-baf9-6b7c164d1caf",
   "metadata": {},
   "source": [
    "## Hyperparameter tuning\n",
    "If the optimal hyperparameters aren't known, we can run a SageMaker Hyperparameter Tuning job, which launches several training jobs and iteratively searches for the best set of hyperparameters.\n",
    "\n",
    "At a high level, a tuning job consists of two main components:\n",
    "- `HyperParameterTrainingJobDefinition`, which specifies the details of each individual training job, such as the image to use, the input channels and the resource configuration.\n",
    "- `HyperParameterTuningJobConfig`, which describes the tuning configuration, such as the strategy to use, how many jobs to run and which parameters to tune.\n",
    "\n",
    "You can find more information about how it works [here](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02419934-5cd4-4fda-9706-c852545e4ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.resources import HyperParameterTuningJob\n",
    "from sagemaker.core.shapes import HyperParameterTuningJobConfig, \\\n",
    "     ResourceLimits, HyperParameterTuningJobWarmStartConfig, ParameterRanges, AutoParameter, \\\n",
    "     Autotune, HyperParameterTrainingJobDefinition, HyperParameterTuningJobObjective, HyperParameterAlgorithmSpecification, \\\n",
    "     OutputDataConfig, StoppingCondition, ResourceConfig\n",
    "\n",
    "tuning_job_name = 'xgboost-tune-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())  # Name of tuning job\n",
    "\n",
    "max_number_of_training_jobs = 50  # Maximum number of training jobs to run as part of the tuning job\n",
    "max_parallel_training_jobs = 5  # Maximum number of training jobs to run in parallel\n",
    "max_runtime_in_seconds = 3600  # Maximum runtime in seconds, applied below to each training job's stopping condition\n",
    "\n",
    "# Create HyperParameterTrainingJobDefinition object, containing information about each individual training job.\n",
    "hyper_parameter_training_job_defintion = HyperParameterTrainingJobDefinition(\n",
    "        role_arn=role,\n",
    "        algorithm_specification=HyperParameterAlgorithmSpecification(\n",
    "            training_image=image,\n",
    "            training_input_mode='File'\n",
    "        ),\n",
    "        input_data_config=[\n",
    "            Channel(\n",
    "                channel_name='train',\n",
    "                content_type='csv',\n",
    "                data_source=DataSource(\n",
    "                    s3_data_source=S3DataSource(\n",
    "                        s3_data_type='S3Prefix',\n",
    "                        s3_uri=s3_train_input,\n",
    "                        s3_data_distribution_type='FullyReplicated'\n",
    "                    )\n",
    "                )\n",
    "            ),\n",
    "            Channel(\n",
    "                channel_name='validation',\n",
    "                content_type='csv',\n",
    "                data_source=DataSource(\n",
    "                    s3_data_source=S3DataSource(\n",
    "                        s3_data_type='S3Prefix',\n",
    "                        s3_uri=s3_validation_input,\n",
    "                        s3_data_distribution_type='FullyReplicated'\n",
    "                    )\n",
    "                )\n",
    "            )\n",
    "        ],\n",
    "        output_data_config=OutputDataConfig(\n",
    "            s3_output_path=s3_output_path\n",
    "        ),\n",
    "        stopping_condition=StoppingCondition(\n",
    "            max_runtime_in_seconds=max_runtime_in_seconds\n",
    "        ),\n",
    "        resource_config=ResourceConfig(\n",
    "            instance_type=instance_type,\n",
    "            instance_count=instance_count,\n",
    "            volume_size_in_gb=volume_size_in_gb,\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Create HyperParameterTuningJobConfig object, containing information about the tuning job\n",
    "tuning_job_config = HyperParameterTuningJobConfig(\n",
    "        strategy='Bayesian',\n",
    "        hyper_parameter_tuning_job_objective=HyperParameterTuningJobObjective(\n",
    "            type='Maximize',\n",
    "            metric_name='validation:auc'\n",
    "        ),\n",
    "        resource_limits=ResourceLimits(\n",
    "            max_number_of_training_jobs=max_number_of_training_jobs,\n",
    "            max_parallel_training_jobs=max_parallel_training_jobs,\n",
    "            max_runtime_in_seconds=3600\n",
    "        ),\n",
    "        training_job_early_stopping_type='Auto',\n",
    "        parameter_ranges=ParameterRanges(\n",
    "            auto_parameters=[\n",
    "                AutoParameter(\n",
    "                    name='max_depth',\n",
    "                    value_hint='5'\n",
    "                ),\n",
    "                AutoParameter(\n",
    "                    name='eta',\n",
    "                    value_hint='0.1'\n",
    "                ),\n",
    "                AutoParameter(\n",
    "                    name='gamma',\n",
    "                    value_hint='8'\n",
    "                ),\n",
    "                AutoParameter(\n",
    "                    name='min_child_weight',\n",
    "                    value_hint='2'\n",
    "                ),\n",
    "                AutoParameter(\n",
    "                    name='subsample',\n",
    "                    value_hint='0.5'\n",
    "                ),\n",
    "                AutoParameter(\n",
    "                    name='num_round',\n",
    "                    value_hint='50'\n",
    "                )\n",
    "            ]\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Create the tuning job using the 2 configuration objects above\n",
    "tuning_job = HyperParameterTuningJob.create(\n",
    "    hyper_parameter_tuning_job_name=tuning_job_name,\n",
    "    autotune=Autotune(\n",
    "        mode='Enabled'\n",
    "    ),\n",
    "    training_job_definition=hyper_parameter_training_job_defintion,\n",
    "    hyper_parameter_tuning_job_config=tuning_job_config\n",
    ")\n",
    "\n",
    "tuning_job.wait()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b79e6746-6a03-4868-a22e-63a32bfcf599",
   "metadata": {},
   "source": [
    "## Use model artifacts for batch inference\n",
    "To use the model to perform batch inference, we can use a SageMaker Batch Transform job. The Transform Job requires a SageMaker model object, which contains information about what image and model to use.\n",
    "\n",
    "Below, we:\n",
    "1. Create a SageMaker model from the same first-party image that we used for training and the model artifacts produced during training. The same image can serve inference as well as training.\n",
    "2. Use that SageMaker model with a Transform Job to perform batch inference on our test dataset\n",
    "3. Compute some performance metrics\n",
    "\n",
    "More information about SageMaker Batch Transform can be found [here](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82b0db2b-4daf-4ee0-a035-b7108f3d7912",
   "metadata": {},
   "source": [
    "#### Create SageMaker Model\n",
    "\n",
    "Create a Model resource based on the model artifacts produced by the best training job from the hyperparameter tuning run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270aa263-2a8f-4225-bcd6-56a57195ac63",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.resources import Model\n",
    "from sagemaker.core.shapes import ContainerDefinition\n",
    "\n",
    "#model_s3_uri = training_job.model_artifacts.s3_model_artifacts  # Get URI of model artifacts from the training job.\n",
    "model_s3_uri = TrainingJob.get(tuning_job.best_training_job.training_job_name).model_artifacts.s3_model_artifacts # Get URI of model artifacts of the best model from the tuning job.\n",
    "\n",
    "\n",
    "# Create SageMaker model: An image along with the model artifact to use.\n",
    "customer_churn_model = Model.create(\n",
    "    model_name='customer-churn-xgboost',\n",
    "    primary_container=ContainerDefinition(\n",
    "        image=image,\n",
    "        model_data_url=model_s3_uri\n",
    "    ),\n",
    "    execution_role_arn=role\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7151c8e0-c0d8-49c9-9bda-28d0fe816b59",
   "metadata": {},
   "source": [
    "#### Create Transform Job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d15f53e5-8cc4-487b-bd64-fe8efff11517",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.resources import TransformJob\n",
    "from sagemaker.core.shapes import TransformInput, TransformDataSource, TransformS3DataSource, TransformOutput, TransformResources\n",
    "\n",
    "model_name = customer_churn_model.get_name()\n",
    "transform_job_name = 'churn-prediction-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())  # Name of the Transform Job\n",
    "s3_output_path = f\"s3://{bucket}/transform\"  # Bucket and optional prefix where the Transform Job stores the result\n",
    "instance_type = 'ml.m4.xlarge'  # SageMaker instance type to use for the Transform Job\n",
    "instance_count = 1  # Number of instances to use for the Transform Job\n",
    "\n",
    "# Create Transform Job.\n",
    "transform_job = TransformJob.create(\n",
    "    transform_job_name=transform_job_name,\n",
    "    model_name=model_name,\n",
    "    transform_input=TransformInput(\n",
    "        data_source=TransformDataSource(\n",
    "            s3_data_source=TransformS3DataSource(\n",
    "                s3_data_type=\"S3Prefix\",\n",
    "                s3_uri=s3_test_input\n",
    "            )\n",
    "        ),\n",
    "        content_type=\"text/csv\"\n",
    "    ),\n",
    "    transform_output=TransformOutput(\n",
    "        s3_output_path=s3_output_path\n",
    "    ),\n",
    "    transform_resources=TransformResources(\n",
    "        instance_type=instance_type,\n",
    "        instance_count=instance_count\n",
    "    )\n",
    ")\n",
    "\n",
    "transform_job.wait()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5eca0ceb-b23b-4d72-9d39-dc1a295af068",
   "metadata": {},
   "source": [
    "#### Compute performance metrics"
   ]
  },
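  {
   "cell_type": "markdown",
   "id": "c5d90a17-6e2b-4f38-a1c4-7b0e8d3f5a63",
   "metadata": {},
   "source": [
    "The next cell locates the Transform Job result by taking the input file name, appending `.out`, and joining it to the output path. The standalone sketch below shows the same derivation with the standard library's `urlparse`. The bucket name and both helper names (`transform_output_uri`, `split_s3_uri`) are hypothetical and only serve as an illustration.\n",
    "\n",
    "```python\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "\n",
    "def transform_output_uri(input_uri, output_prefix):\n",
    "    '''Build the expected result URI: <output_prefix>/<input file name>.out'''\n",
    "    input_file_name = urlparse(input_uri).path.rsplit('/', 1)[-1]\n",
    "    prefix = output_prefix.rstrip('/')\n",
    "    return f'{prefix}/{input_file_name}.out'\n",
    "\n",
    "\n",
    "def split_s3_uri(s3_uri):\n",
    "    '''Split an s3:// URI into (bucket, key).'''\n",
    "    parsed = urlparse(s3_uri)\n",
    "    return parsed.netloc, parsed.path.lstrip('/')\n",
    "\n",
    "\n",
    "# Hypothetical bucket name, for illustration only\n",
    "output_uri = transform_output_uri(\n",
    "    's3://example-bucket/data/test.csv', 's3://example-bucket/transform'\n",
    ")\n",
    "print(output_uri)\n",
    "print(split_s3_uri(output_uri))\n",
    "```"
   ]
  },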
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b55cb28-ed28-4a96-b96e-05ac8944f1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score\n",
    "\n",
    "# A Transform Job writes its results to the given S3 output path, using the input file name with \".out\" appended.\n",
    "output_file_name = transform_job.transform_input.data_source.s3_data_source.s3_uri.split('/')[-1] + '.out'  # Get output file name\n",
    "output_s3_uri = f\"{transform_job.transform_output.s3_output_path}/{output_file_name}\"  # Create output S3 URI\n",
    "\n",
    "def split_s3_path(s3_path):\n",
    "    '''Lightweight method for extracting bucket and object key from S3 uri'''\n",
    "    path_parts = s3_path.replace(\"s3://\", \"\").split(\"/\")\n",
    "    bucket = path_parts.pop(0)\n",
    "    key = \"/\".join(path_parts)\n",
    "    return bucket, key\n",
    "\n",
    "def print_performance_metrics(probs, y, threshold = 0.5):\n",
    "    '''Lightweight method for printing performance metrics'''\n",
    "    \n",
    "    predictions = (probs >= threshold).astype(int)\n",
    "\n",
    "    # Compare predictions with the stored target\n",
    "    accuracy = accuracy_score(y, predictions)\n",
    "    precision = precision_score(y, predictions)\n",
    "    recall = recall_score(y, predictions)\n",
    "    roc_auc = roc_auc_score(y, probs)\n",
    "\n",
    "    print(f\"Accuracy: {accuracy}\")\n",
    "    print(f\"Precision: {precision}\")\n",
    "    print(f\"Recall: {recall}\")\n",
    "    print(f\"ROC AUC: {roc_auc}\")\n",
    "\n",
    "\n",
    "# Extract bucket and key separately from uri\n",
    "res_bucket, res_key = split_s3_path(output_s3_uri)\n",
    "\n",
    "# Download Transform Job results\n",
    "transform_job_result = sagemaker_session.read_s3_file(res_bucket, res_key)\n",
    "transform_job_result = pd.read_csv(StringIO(transform_job_result), header=None)\n",
    "\n",
    "print_performance_metrics(transform_job_result, test_target_column)"
   ]
  },
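  {
   "cell_type": "markdown",
   "id": "1f6e3b84-2a9c-4d05-9b7e-0c8a5d4e7f29",
   "metadata": {},
   "source": [
    "`print_performance_metrics` applies a single decision threshold (0.5 by default) to turn scores into class predictions, so accuracy, precision and recall all depend on that choice, while ROC AUC is computed from the raw scores. The sketch below shows how precision and recall move as the threshold changes. It uses made-up scores and labels and a hypothetical `precision_recall` helper written in plain Python; it does not use the model's real output.\n",
    "\n",
    "```python\n",
    "def precision_recall(probs, labels, threshold):\n",
    "    '''Compute precision and recall for one decision threshold.'''\n",
    "    predictions = [int(p >= threshold) for p in probs]\n",
    "    pairs = list(zip(predictions, labels))\n",
    "    true_pos = sum(1 for p, y in pairs if p == 1 and y == 1)\n",
    "    false_pos = sum(1 for p, y in pairs if p == 1 and y == 0)\n",
    "    false_neg = sum(1 for p, y in pairs if p == 0 and y == 1)\n",
    "    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0\n",
    "    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0\n",
    "    return precision, recall\n",
    "\n",
    "\n",
    "# Made-up scores and labels, for illustration only\n",
    "probs = [0.95, 0.80, 0.60, 0.40, 0.30, 0.10]\n",
    "labels = [1, 1, 0, 1, 0, 0]\n",
    "\n",
    "for threshold in (0.3, 0.5, 0.7):\n",
    "    precision, recall = precision_recall(probs, labels, threshold)\n",
    "    print(f'threshold={threshold}: precision={precision:.2f}, recall={recall:.2f}')\n",
    "```\n",
    "\n",
    "On this toy data, the higher threshold gives fewer false positives at the cost of missed churners. Which trade-off is preferable depends on the relative cost of each kind of error."
   ]
  },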
  {
   "cell_type": "markdown",
   "id": "abfe8ade-8de6-487a-9b24-a2afbeaa8559",
   "metadata": {},
   "source": [
    "## Create SageMaker endpoint for real-time inference\n",
    "To create a SageMaker endpoint we first create an `EndpointConfig`. The endpoint configuration specifies what SageMaker model to use, and what endpoint type. We then use the `EndpointConfig` together with other optional parameters to create a SageMaker Endpoint.\n",
    "\n",
    "More information about SageMaker Endpoints can be found [here](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "056e57bb-9edf-461e-97e8-40c9e5ebf39f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.core.resources import Endpoint, EndpointConfig\n",
    "from sagemaker.core.shapes import ProductionVariant\n",
    "\n",
    "endpoint_config_name = 'churn-prediction-endpoint-config'  # Name of endpoint configuration\n",
    "model_name = customer_churn_model.get_name()  # Get name of SageMaker model created in previous step\n",
    "endpoint_name = \"customer-churn-endpoint\"  # Name of SageMaker endpoint\n",
    "\n",
    "endpoint_config = EndpointConfig.create(\n",
    "    endpoint_config_name=endpoint_config_name,\n",
    "    production_variants=[\n",
    "        ProductionVariant(\n",
    "            variant_name='AllTraffic',\n",
    "            model_name=model_name,\n",
    "            instance_type=instance_type,\n",
    "            initial_instance_count=1\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "sagemaker_endpoint = Endpoint.create(\n",
    "    endpoint_name=endpoint_name,\n",
    "    endpoint_config_name=endpoint_config.get_name(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "635cb615-f4e0-43fa-91d6-0b480a51fd16",
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_endpoint.wait_for_status(target_status='InService')  # Wait for the endpoint to reach the InService status"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd718d54-65ee-4b67-8556-196246356a68",
   "metadata": {},
   "source": [
    "#### Test live endpoint - with one sample\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fdbd048-014f-418d-8d88-fc329afc7ee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract one sample payload and convert to string\n",
    "sample = test_data.sample(1)\n",
    "sample_payload = sample.to_csv(header=False, index=False).strip()\n",
    "\n",
    "# Send sample payload to live endpoint and parse response\n",
    "res = sagemaker_endpoint.invoke(body=sample_payload, content_type=\"text/csv\")\n",
    "result = res['Body'].read().decode('utf-8')\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ff78e4f-47cd-4f05-ba00-6ded32ad12a8",
   "metadata": {},
   "source": [
    "#### Test live endpoint - with entire test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d66fd20a-7016-4b55-9144-633461059b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert entire test dataset to CSV string\n",
    "sample_payload = test_data.to_csv(header=False, index=False).strip()\n",
    "\n",
    "# Send the full test payload to the live endpoint and parse the response\n",
    "res = sagemaker_endpoint.invoke(body=sample_payload, content_type=\"text/csv\")\n",
    "result = res['Body'].read().decode('utf-8')\n",
    "result = result.split('\\n')[:-1]\n",
    "\n",
    "# Compute performance metrics\n",
    "df_result = pd.DataFrame(result).astype(float)\n",
    "print_performance_metrics(df_result, test_target_column)"
   ]
  },
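  {
   "cell_type": "markdown",
   "id": "9a4c7e20-5d1f-4b63-8e02-6f3b1d9a0c74",
   "metadata": {},
   "source": [
    "The cell above splits the response on newlines and drops the last element, which relies on the body ending with a newline. A slightly more tolerant parser is sketched below. It assumes the body is plain text with scores separated by newlines and/or commas; that is an assumption about the response format, not something taken from the container's documentation. `parse_predictions` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "\n",
    "def parse_predictions(response_text):\n",
    "    '''Parse a text body of scores into a list of floats.'''\n",
    "    # Assumption: scores are separated by newlines and/or commas.\n",
    "    tokens = re.split(r'[,\\s]+', response_text.strip())\n",
    "    return [float(token) for token in tokens if token]\n",
    "\n",
    "\n",
    "# Made-up response bodies, for illustration only\n",
    "print(parse_predictions('0.91\\n0.07\\n0.45\\n'))\n",
    "print(parse_predictions('0.91,0.07,0.45'))\n",
    "```"
   ]
  },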
  {
   "cell_type": "markdown",
   "id": "0e9d3ad7-4b67-405e-940a-0331fd28811f",
   "metadata": {},
   "source": [
    "## Clean up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fb90949-9b35-4e11-8d2a-6053beacb4ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_endpoint.delete()\n",
    "endpoint_config.delete()\n",
    "customer_churn_model.delete()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
